import gc
import json
import random
import traceback
import time
from datetime import date


from colorama import Fore
from playwright.sync_api import sync_playwright

from dataModel import JJs, Stocks, db


class Now_JJ:
    """One row of the fund ranking table.

    Attributes hold the raw cell text scraped from the list page: the fund
    code, its name, the link to its detail page, the listed date and the
    net value.
    """

    def __init__(self, code, name, url, day, net):
        self.code, self.name, self.url, self.day, self.net = (
            code,
            name,
            url,
            day,
            net,
        )


class Spider:
    """Crawler for Tiantian Fund (fund.eastmoney.com).

    ``run`` collects the fund ranking list page by page, then visits every
    fund in random order and stores its scale and stock holdings in the
    ``JJs`` and ``Stocks`` tables, keyed by today's date.
    """

    def __init__(self) -> None:
        self.year = date.today().year
        # ISO "YYYY-MM-DD" string; every row written by this run uses it as its date.
        self.today = str(date.today())
        # Funds that still failed after all retries in try_get.
        self.errors = []
        # Now_JJ records collected by get_jj_list.
        self.jjs = []
        # Page counter in get_jj_list, reused as the item counter in rand_get.
        self.cnt = 1

    def get_jj_list(self):
        """Collect every row of every ranking-table page into ``self.jjs``."""
        page = self.p
        page.goto("http://fund.eastmoney.com/data/fundranking.html")
        page.click("#types > li:nth-child(2)")
        # Base queries that getJJData narrows per fund / per stock.
        self.jSelect = JJs.select()
        self.sSelect = Stocks.select()
        while True:
            print(f"获取第{self.cnt}页的基金列表。")
            for row in page.query_selector_all("#dbtable > tbody > tr"):
                code = row.query_selector("td:nth-child(3)").text_content()
                name = row.query_selector("td:nth-child(4)").text_content()
                url = row.query_selector(
                    "td:nth-child(4) > a").get_attribute("href")
                day = row.query_selector("td:nth-child(5)").text_content()
                net = row.query_selector("td:nth-child(6)").text_content()
                self.jjs.append(Now_JJ(code, name, url, day, net))
                print(code, name, url, day, net)
            print(f"第{self.cnt}页的基金列表获取完毕。")
            self.cnt += 1
            if not self.nextPage():
                break

    def nextPage(self):
        """Go to the next ranking page.

        Returns 1 after clicking "下一页". Returns 0 without clicking when the
        button is missing or carries the ``end`` class. The check happens
        *before* the click, so the caller still scrapes the final page.
        """
        selector = 'label:has-text("下一页")'
        next_btn = self.p.query_selector(selector)
        if next_btn is None:
            return 0
        classes = (next_btn.get_attribute("class") or "").split()
        if "end" in classes:
            return 0
        print(f"{Fore.BLUE}翻页中。。。{Fore.RESET}")
        self.p.click(selector)
        print("-" * 100)
        return 1

    @staticmethod
    def _progress(done, total):
        """Return a text progress bar: one '=' per percent, then the percentage.

        ``done`` is the number of items handled including the current one.
        An empty ``total`` counts as complete.
        """
        pct = done / total * 100 if total else 100.0
        return "=" * int(pct) + f"> {pct:.2f}%"

    def try_get(self, jj, n=3):
        """Try up to ``n`` times to crawl ``jj``.

        Prints a progress bar, and appends ``jj`` to ``self.errors`` if every
        attempt raised. Always advances ``self.cnt`` by one.
        """
        name = jj.name
        code = jj.code
        print(self._progress(self.cnt + 1, len(self.jjs)))
        for attempt in range(1, n + 1):
            print(f"第{attempt}次尝试{Fore.RED}{name}({code}){Fore.RESET}数据爬取。。。")
            try:
                self.getJJData(jj)
            except Exception:  # any scrape/DB failure triggers a retry
                traceback.print_exc()
                print(f"{Fore.RED}{name}({code}){Fore.RESET}数据获取失败，开始重试。。。")
            else:
                break
        else:
            # No attempt succeeded (also reached when n <= 0).
            self.errors.append(jj)
        self.cnt += 1

    @staticmethod
    def _parse_scale(text):
        """Parse the fund scale from text like "基金规模：12.34亿元（...）".

        Returns the number between the full-width colon and "亿" as a float,
        or 0.0 when that part is not numeric. Raises IndexError when the text
        has no "：".
        """
        value = text.split("：")[1].split("亿")[0]
        try:
            return float(value)
        except ValueError:
            return 0.0

    @staticmethod
    def _parse_ratio(text):
        """Convert a percentage cell such as "5.25%" into a fraction (0.0525)."""
        return float(text.split("%")[0]) / 100

    def _scrape_fund_page(self, url):
        """Open a fund page and return ``(scale, holdings)``.

        ``holdings`` is a list of ``[stock_name, ratio, ratio * scale]``. It is
        empty when the scale is 0 or the holdings table is too short. Nothing
        is written to the database here.
        """
        sp = self.p
        sp.goto(url)
        money = self._parse_scale(
            sp.query_selector('td:has-text("基金规模")').text_content()
        )
        rows = sp.query_selector_all(
            "#position_shares > div.poptableWrap > table > tbody > tr"
        )
        holdings = []
        # rows[0] is skipped as the header row.
        # NOTE(review): "> 2" treats a fund with a single holding row as
        # undisclosed — confirm this threshold is intended.
        if len(rows) > 2 and money:
            for row in rows[1:]:
                name = row.query_selector("td:nth-child(1)").text_content()
                name = name.replace(" ", "")
                per = self._parse_ratio(
                    row.query_selector("td:nth-child(2)").text_content()
                )
                holdings.append([name, per, money * per])
        return money, holdings

    def _add_stock(self, name, sMoney, day):
        """Add one fund's holding to the per-stock aggregate row for ``day``."""
        sdata = self.sSelect.where(
            (Stocks.name == name), (Stocks.date == day)
        ).first()
        if sdata is not None:
            sdata.n += 1
            sdata.money += sMoney
        else:
            sdata = Stocks(name=name, n=1, money=sMoney, date=day)
        sdata.save()

    def getJJData(self, jj):
        """Crawl one fund and store it, plus its holdings, under today's date.

        The page visit is skipped when a ``JJs`` row for this code and date
        already exists. The database connection is closed and reopened after
        every fund.
        """
        jjname = jj.name
        code = jj.code
        url = jj.url
        if not url.startswith("http"):
            url = "http://fund.eastmoney.com/" + url
        # Every row is stored under the crawl date; the listing's own date
        # (jj.day) is not used.
        day = self.today
        net = jj.net
        exists = self.jSelect.where(
            (JJs.code == code),
            (JJs.date == day),
        ).exists()
        if exists:
            print(f"{Fore.RED}{jjname}({code}){Fore.RESET}最新数据已存在，本次不爬取。")
        else:
            print(f"{Fore.RED}{jjname}({code}){Fore.RESET}最新数据不存在，开始爬取...")
            money, stocksData = self._scrape_fund_page(url)
            if not stocksData:
                print(f"{Fore.RED}{jjname}({code}){Fore.RESET}没有公布持仓！")
            # Parse first, then write everything in one transaction. A failure
            # therefore cannot leave Stocks partially incremented, which a
            # retry would otherwise double-count.
            with db.atomic():
                for name, _per, sMoney in stocksData:
                    self._add_stock(name, sMoney, day)
                JJs(
                    code=code,
                    name=jjname,
                    url=url,
                    money=money,
                    stocks=json.dumps(stocksData, ensure_ascii=False),
                    date=day,
                    net=net,
                ).save()
        db.close()
        db.connect()
        print(f"{Fore.GREEN}{jjname}({code}){Fore.RESET}爬取完毕！")
        print("#" * 50)
        gc.collect()  # reclaim memory after each fund

    def rand_get(self):
        """Crawl every collected fund in shuffled order, then list failures."""
        random.shuffle(self.jjs)  # randomise order to look less like a bot
        self.cnt = 0
        for jj in self.jjs:
            self.try_get(jj)
            time.sleep(random.random() + 0.5)  # 0.5–1.5 s pause between funds

        for jj in self.errors:
            print(jj.code, jj.name)

    def run(self):
        """Launch headless Chromium, collect the fund list, then crawl every fund."""
        with sync_playwright() as pw:
            browser = pw.chromium.launch(
                headless=True,
                slow_mo=1000,
            )
            self.b = browser
            try:
                self.p = browser.new_page()
                self.get_jj_list()
                self.rand_get()
            finally:
                # Closing the browser also closes the page, even on error.
                browser.close()


if __name__ == "__main__":
    # Script entry point: build the crawler and run the full crawl.
    Spider().run()
